{
  "cells": [
    {
      "cell_type": "code",
      "id": "48bce2a7-e933-4dd8-a952-9ab940920ff2",
      "metadata": {},
      "outputs": [],
      "source": [
        "import sys\n",
        "import os\n",
        "\n",
        "import sklearn\n",
        "\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        ""
      ],
      "execution_time": "0.68 s",
      "execution_count": 2
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "4cc691d4-7486-4756-b770-2d0da5a383eb",
      "kernelspec": {
        "display_name": "Python 3",
        "language": "python",
        "name": "python3"
      },
      "outputs": [],
      "source": [
        "r1 \u003d torch.tensor([1, 0, 0, 0])\n",
        "r2 \u003d torch.tensor([0, 1, 0, 0])\n",
        "r3 \u003d torch.tensor([0, 0, 1, 0])\n",
        "r4 \u003d torch.tensor([0, 0, 0, 1])"
      ],
      "execution_time": "0.01 s"
    },
    {
      "cell_type": "code",
      "id": "94cb9896-bcaa-415b-ab8f-232cc1221a84",
      "kernelspec": {
        "display_name": "Python 3",
        "language": "python",
        "name": "python3"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "metadata": {},
          "execution_count": 4,
          "data": {
            "text/plain": [
              "tensor([[[-82.4494, -82.1640]]], grad_fn\u003d\u003cSqueezeBackward1\u003e)"
            ]
          }
        }
      ],
      "source": [
        "d1 \u003d torch.nn.Linear(100, 50) # y \u003d A(100x50) * x + b,\n",
        "input \u003d torch.rand(140, 100)\n",
        "out \u003d d1(input)\n",
        "out.shape\n",
        "\n",
        "\n",
        "d2 \u003d torch.nn.Linear(4, 5)\n",
        "input_matrix \u003d torch.tensor([[1.0, 0.0, 0.0, 0.0]])\n",
        "F.relu(d2(input_matrix))\n",
        "\n",
        "d3 \u003d torch.nn.Conv1d(4, 1, 3)\n",
        "\n",
        "time_input \u003d torch.tensor([[\n",
        "    [323.0, 324.0, 323.0, 325.0,], # tick\n",
        "    [310.0, 310.0, 310.0, 311.0,], # min\n",
        "    [300.0, 300.0, 300.0, 300.0],  # hour\n",
        "    [300.0, 300.0, 300.0, 300.0],  # hour\n",
        "]])\n",
        "\n",
        "d4 \u003d torch.nn.Conv2d(4, 1, 3)\n",
        "\n",
        "d3(time_input)"
      ],
      "execution_count": 4,
      "execution_time": "1.00 s"
    },
    {
      "cell_type": "code",
      "id": "aa0c457a-1506-4bdd-b4cd-6c4cd2811513",
      "kernelspec": {
        "display_name": "Python 3",
        "language": "python",
        "name": "python3"
      },
      "outputs": [],
      "source": [
        "metric_num \u003d 15\n",
        "step_num \u003d 20\n",
        "time_scales \u003d len([\u0027tick\u0027, \u0027min\u0027, \u00275min\u0027, \u002715min\u0027, \u0027hour\u0027, \u0027day\u0027])\n",
        "action_count \u003d len([\u0027b1\u0027, \u0027s1\u0027, \u0027h\u0027, \u0027c\u0027])\n",
        "\n",
        "fake_tick \u003d torch.randn((time_scales, step_num, metric_num)) #\n",
        "fake_tick"
      ]
    },
    {
      "cell_type": "code",
      "id": "e010ff3d-c740-42fe-9af4-b0b08a4eee8b",
      "kernelspec": {
        "display_name": "Python 3",
        "language": "python",
        "name": "python3"
      },
      "outputs": [],
      "source": [
        "#%%\n",
        "#%%\n",
        "\n",
        "conv_out_rows \u003d 10\n",
        "\n",
        "mod \u003d torch.nn.Sequential(\n",
        "    torch.nn.Conv1d(step_num, conv_out_rows, 4),\n",
        "    torch.nn.Flatten(end_dim\u003d1)\n",
        ")\n",
        "out \u003d mod(fake_tick)\n",
        "out.size()\n",
        "\n",
        "#%%\n",
        "\n",
        "out.reshape(-1).size()\n",
        "\n",
        "fake_tick.size()\n",
        "t1 \u003d torch.nn.Conv2d(time_scales, 9, (5, 3))\n",
        "t1(fake_tick.unsqueeze(0)).view(-1).size()\n",
        "\n",
        "\n",
        "#%%\n",
        "\n",
        "kernel_h \u003d 6 # segment_range\n",
        "kernel_w \u003d 3 # metric_kinds\n",
        "\n",
        "cons \u003d torch.nn.Sequential(\n",
        "    torch.nn.Conv2d(time_scales, time_scales, (kernel_h, kernel_w), padding\u003d1),\n",
        "    torch.nn.ReLU(),\n",
        "    torch.nn.Conv2d(time_scales, time_scales, (kernel_h, kernel_w), padding\u003d1),\n",
        "    torch.nn.ReLU(),\n",
        "    torch.nn.Conv2d(time_scales, time_scales, (kernel_h, kernel_w), padding\u003d1),\n",
        "    torch.nn.ReLU()\n",
        ")\n",
        "\n",
        "cons(fake_tick.unsqueeze(0)).size()\n",
        "#fake_tick.unsqueeze(0).size()\n",
        "in_w \u003d len(cons(fake_tick.unsqueeze(0)).view(-1))\n",
        "\n",
        "lins \u003d torch.nn.Sequential(\n",
        "    torch.nn.Linear(in_w, in_w),\n",
        "    torch.nn.Linear(in_w, in_w),\n",
        "    torch.nn.Linear(in_w, action_count),\n",
        ")\n",
        "\n",
        "def forward(x):\n",
        "    con_out \u003d cons(x.unsqueeze(0))\n",
        "    lin_out \u003d lins(con_out.view(-1))\n",
        "    return lin_out, torch.nn.functional.softmax(lin_out)\n",
        "\n",
        "forward(fake_tick)\n",
        "\n",
        "#%%\n",
        "\n",
        "t1 \u003d torch.tensor([0.1, 0.4, 0.3, 0.5])\n",
        "torch.argmax(t1)\n",
        "\n",
        "#%%\n",
        "\n",
        "r1 \u003d torch.tensor([[[7.0, 22.0, 14.0, 29.0, 49.0, 3.0, 3791.0, 3791.0, 3791.0, 3791.0, 786.0, 430.0, 3471.0, 3994.0, 960782.0], [7.0, 22.0, 14.0, 29.0, 50.0, 3.0, 3791.0, 3791.0, 3791.0, 3791.0, 786.0, 433.0, 3471.0, 3994.0, 960782.0], [7.0, 22.0, 14.0, 29.0, 50.0, 3.0, 3791.0, 3791.0, 3791.0, 3791.0, 789.0, 430.0, 3471.0, 3994.0, 960786.0], [7.0, 22.0, 14.0, 29.0, 51.0, 3.0, 3790.000000000001, 3790.000000000001, 3790.000000000001, 3790.000000000001, 790.0, 428.0, 3471.0, 3994.0, 960788.0], [7.0, 22.0, 14.0, 29.0, 51.0, 3.0, 3790.000000000001, 3790.000000000001, 3790.000000000001, 3790.000000000001, 811.0, 86.0, 3471.0, 3994.0, 960788.0], [7.0, 22.0, 14.0, 29.0, 52.0, 3.0, 3790.000000000001, 3790.000000000001, 3790.000000000001, 3790.000000000001, 812.0, 63.0, 3471.0, 3994.0, 960788.0], [7.0, 22.0, 14.0, 29.0, 52.0, 3.0, 3790.000000000001, 3790.000000000001, 3790.000000000001, 3790.000000000001, 802.0, 65.0, 3471.0, 3994.0, 960798.0], [7.0, 22.0, 14.0, 29.0, 53.0, 3.0, 3790.000000000001, 3790.000000000001, 3790.000000000001, 3790.000000000001, 803.0, 66.0, 3471.0, 3994.0, 960798.0], [7.0, 22.0, 14.0, 29.0, 53.0, 3.0, 3791.0, 3791.0, 3791.0, 3791.0, 813.0, 66.0, 3471.0, 3994.0, 960799.0], [7.0, 22.0, 14.0, 29.0, 54.0, 3.0, 3791.0, 3791.0, 3791.0, 3791.0, 817.0, 417.0, 3471.0, 3994.0, 960814.0], [7.0, 22.0, 14.0, 29.0, 54.0, 3.0, 3790.000000000001, 3790.000000000001, 3790.000000000001, 3790.000000000001, 795.0, 441.0, 3471.0, 3994.0, 960817.0], [7.0, 22.0, 14.0, 29.0, 55.0, 3.0, 3790.000000000001, 3790.000000000001, 3790.000000000001, 3790.000000000001, 795.0, 425.0, 3471.0, 3994.0, 960817.0], [7.0, 22.0, 14.0, 29.0, 55.0, 3.0, 3790.000000000001, 3790.000000000001, 3790.000000000001, 3790.000000000001, 799.0, 565.0, 3471.0, 3994.0, 960817.0], [7.0, 22.0, 14.0, 29.0, 56.0, 3.0, 3790.000000000001, 3790.000000000001, 3790.000000000001, 3790.000000000001, 795.0, 567.0, 3471.0, 3994.0, 960821.0], [7.0, 22.0, 14.0, 29.0, 56.0, 3.0, 3791.0, 3791.0, 3791.0, 3791.0, 795.0, 568.0, 3471.0, 3994.0, 960822.0], [7.0, 22.0, 14.0, 29.0, 57.0, 3.0, 3791.0, 3791.0, 3791.0, 3791.0, 795.0, 568.0, 3471.0, 3994.0, 960822.0], [7.0, 22.0, 14.0, 29.0, 58.0, 3.0, 3791.0, 3791.0, 3791.0, 3791.0, 795.0, 568.0, 3471.0, 3994.0, 960822.0], [7.0, 22.0, 14.0, 29.0, 58.0, 3.0, 3791.0, 3791.0, 3791.0, 3791.0, 815.0, 575.0, 3471.0, 3994.0, 960822.0], [7.0, 22.0, 14.0, 29.0, 59.0, 3.0, 3791.0, 3791.0, 3791.0, 3791.0, 809.0, 594.0, 3471.0, 3994.0, 960829.0], [7.0, 22.0, 14.0, 29.0, 59.0, 3.0, 3791.0, 3791.0, 3791.0, 3791.0, 810.0, 594.0, 3471.0, 3994.0, 960829.0]], [[7.0, 22.0, 14.0, 10.0, 59.0, 3.0, 3794.0, 3795.000000000001, 3793.0, 3794.000000000001, 369.0, 170.0, 3471.0, 3994.0, 923444.0], [7.0, 22.0, 14.0, 11.0, 59.0, 3.0, 3794.000000000001, 3795.0, 3794.0, 3795.0, 261.0, 165.0, 3471.0, 3994.0, 924477.0], [7.0, 22.0, 14.0, 12.0, 59.0, 3.0, 3795.0, 3796.0, 3794.0, 3795.0, 90.0, 1072.0, 3471.0, 3994.0, 924883.0], [7.0, 22.0, 14.0, 13.0, 59.0, 3.0, 3796.0, 3797.000000000001, 3794.0, 3796.0, 402.0, 977.0, 3471.0, 3994.0, 926348.0], [7.0, 22.0, 14.0, 14.0, 59.0, 3.0, 3796.0, 3797.0, 3795.0, 3797.0, 55.0, 837.0, 3471.0, 3994.0, 927739.0], [7.0, 22.0, 14.0, 15.0, 59.0, 3.0, 3797.0, 3797.000000000001, 3795.0, 3797.0, 288.0, 826.0, 3471.0, 3994.0, 928454.0], [7.0, 22.0, 14.0, 16.0, 59.0, 3.0, 3796.000000000001, 3797.0, 3796.0, 3797.0, 177.0, 652.0, 3471.0, 3994.0, 929037.0], [7.0, 22.0, 14.0, 17.0, 59.0, 3.0, 3797.0, 3798.0, 3796.0, 3798.0, 5.0, 1030.0, 3471.0, 3994.0, 930132.0], [7.0, 22.0, 14.0, 18.0, 59.0, 3.0, 3797.0, 3798.0, 3796.0, 3798.0, 403.0, 359.0, 3471.0, 3994.0, 931980.0], [7.0, 22.0, 14.0, 19.0, 59.0, 3.0, 3798.0, 3799.000000000001, 3797.0, 3799.0, 468.0, 582.0, 3471.0, 3994.0, 933383.0], [7.0, 22.0, 14.0, 20.0, 59.0, 3.0, 3799.0, 3799.000000000001, 3797.0, 3799.0, 81.0, 413.0, 3471.0, 3994.0, 935689.0], [7.0, 22.0, 14.0, 21.0, 59.0, 3.0, 3799.0, 3800.0, 3797.0, 3798.0, 325.0, 26.0, 3471.0, 3994.0, 937973.0], [7.0, 22.0, 14.0, 22.0, 59.0, 3.0, 3798.0, 3799.0, 3797.0, 3797.0, 314.0, 418.0, 3471.0, 3994.0, 938826.0], [7.0, 22.0, 14.0, 23.0, 59.0, 3.0, 3797.0, 3798.0, 3794.0, 3794.0, 488.0, 11.0, 3471.0, 3994.0, 940542.0], [7.0, 22.0, 14.0, 24.0, 59.0, 3.0, 3794.0, 3794.000000000001, 3792.0, 3794.000000000001, 203.0, 609.0, 3471.0, 3994.0, 942662.0], [7.0, 22.0, 14.0, 25.0, 59.0, 3.0, 3794.000000000001, 3794.000000000001, 3788.0, 3790.0, 278.0, 125.0, 3471.0, 3994.0, 949933.0], [7.0, 22.0, 14.0, 26.0, 59.0, 3.0, 3790.0, 3791.0, 3789.0, 3790.0, 243.0, 575.0, 3471.0, 3994.0, 952140.0], [7.0, 22.0, 14.0, 27.0, 59.0, 3.0, 3790.0, 3790.0, 3787.0, 3788.0, 101.0, 296.0, 3471.0, 3994.0, 955588.0], [7.0, 22.0, 14.0, 28.0, 59.0, 3.0, 3788.0, 3791.0, 3788.0, 3790.0, 594.0, 375.0, 3471.0, 3994.0, 958904.0], [7.0, 22.0, 14.0, 29.0, 59.0, 3.0, 3790.0, 3792.000000000001, 3790.0, 3791.0, 810.0, 594.0, 3471.0, 3994.0, 960829.0]], [[7.0, 22.0, 10.0, 54.0, 59.0, 3.0, 3783.0, 3791.000000000001, 3782.0, 3790.0, 230.0, 381.0, 3471.0, 3994.0, 798760.0], [7.0, 22.0, 10.0, 59.0, 59.0, 3.0, 3790.0, 3793.0, 3788.0, 3791.0, 58.0, 379.0, 3471.0, 3994.0, 808204.0], [7.0, 22.0, 11.0, 4.0, 59.0, 3.0, 3791.0, 3792.0, 3788.0, 3792.0, 207.0, 593.0, 3471.0, 3994.0, 814095.0], [7.0, 22.0, 11.0, 9.0, 59.0, 3.0, 3792.0, 3794.0, 3791.0, 3792.0, 381.0, 335.0, 3471.0, 3994.0, 819017.0], [7.0, 22.0, 11.0, 14.0, 59.0, 3.0, 3791.0, 3802.0, 3791.0, 3796.0, 491.0, 197.0, 3471.0, 3994.0, 839485.0], [7.0, 22.0, 11.0, 19.0, 59.0, 3.0, 3796.0, 3799.0, 3793.0, 3797.0, 228.0, 213.0, 3471.0, 3994.0, 845922.0], [7.0, 22.0, 11.0, 24.0, 59.0, 3.0, 3797.0, 3802.0, 3797.0, 3801.0, 386.0, 39.0, 3471.0, 3994.0, 853868.0], [7.0, 22.0, 11.0, 29.0, 59.0, 3.0, 3801.0, 3801.000000000001, 3795.0, 3798.0, 13.0, 21.0, 3471.0, 3994.0, 859603.0], [7.0, 22.0, 13.0, 34.0, 59.0, 3.0, 3793.0, 3795.000000000001, 3787.0, 3793.0, 199.0, 215.0, 3471.0, 3994.0, 876786.0], [7.0, 22.0, 13.0, 39.0, 59.0, 3.0, 3793.0, 3799.0, 3792.0, 3796.0, 429.0, 16.0, 3471.0, 3994.0, 884447.0], [7.0, 22.0, 13.0, 44.0, 59.0, 3.0, 3795.0, 3800.0, 3789.0, 3792.0, 394.0, 273.0, 3471.0, 3994.0, 898271.0], [7.0, 22.0, 13.0, 49.0, 59.0, 3.0, 3792.0, 3795.000000000001, 3791.0, 3794.0, 302.0, 229.0, 3471.0, 3994.0, 901828.0], [7.0, 22.0, 13.0, 54.0, 59.0, 3.0, 3793.0, 3795.000000000001, 3790.0, 3791.0, 86.0, 210.0, 3471.0, 3994.0, 908915.0], [7.0, 22.0, 13.0, 59.0, 59.0, 3.0, 3791.0, 3795.000000000001, 3791.0, 3793.0, 189.0, 242.0, 3471.0, 3994.0, 912888.0], [7.0, 22.0, 14.0, 4.0, 59.0, 3.0, 3794.0, 3796.000000000001, 3792.0, 3794.0, 496.0, 133.0, 3471.0, 3994.0, 917873.0], [7.0, 22.0, 14.0, 9.0, 59.0, 3.0, 3794.0, 3796.0, 3792.0, 3794.0, 875.0, 335.0, 3471.0, 3994.0, 922538.0], [7.0, 22.0, 14.0, 14.0, 59.0, 3.0, 3794.0, 3797.000000000001, 3793.0, 3797.0, 55.0, 837.0, 3471.0, 3994.0, 927739.0], [7.0, 22.0, 14.0, 19.0, 59.0, 3.0, 3797.0, 3799.000000000001, 3795.0, 3799.0, 468.0, 582.0, 3471.0, 3994.0, 933383.0], [7.0, 22.0, 14.0, 24.0, 59.0, 3.0, 3799.0, 3800.0, 3792.0, 3794.000000000001, 203.0, 609.0, 3471.0, 3994.0, 942662.0], [7.0, 22.0, 14.0, 29.0, 59.0, 3.0, 3794.000000000001, 3794.000000000001, 3787.0, 3791.0, 810.0, 594.0, 3471.0, 3994.0, 960829.0]], [[7.0, 22.0, 21.0, 29.0, 59.0, 3.0, 3790.0, 3792.000000000001, 3784.0, 3789.0, 491.0, 18.0, 3471.0, 3994.0, 239962.0], [7.0, 22.0, 21.0, 44.0, 59.0, 3.0, 3789.0, 3791.0, 3773.0, 3778.000000000001, 177.0, 398.0, 3471.0, 3994.0, 293251.0], [7.0, 22.0, 21.0, 59.0, 59.0, 3.0, 3778.000000000001, 3792.0, 3777.0, 3790.0, 392.0, 552.0, 3471.0, 3994.0, 344569.0], [7.0, 22.0, 22.0, 14.0, 59.0, 3.0, 3789.0, 3792.000000000001, 3787.0, 3791.0, 437.0, 414.0, 3471.0, 3994.0, 367645.0], [7.0, 22.0, 22.0, 29.0, 59.0, 3.0, 3791.0, 3795.000000000001, 3785.0, 3786.0, 307.0, 284.0, 3471.0, 3994.0, 400256.0], [7.0, 22.0, 22.0, 44.0, 59.0, 3.0, 3786.0, 3792.000000000001, 3786.0, 3790.0, 44.0, 884.0, 3471.0, 3994.0, 418239.0], [7.0, 22.0, 22.0, 59.0, 59.0, 3.0, 3790.0, 3794.0, 3787.0, 3790.0, 11.0, 15.0, 3471.0, 3994.0, 442464.0], [7.0, 22.0, 9.0, 14.0, 59.0, 3.0, 3794.0, 3816.0, 3790.0, 3806.0, 506.0, 65.0, 3471.0, 3994.0, 561347.0], [7.0, 22.0, 9.0, 29.0, 59.0, 3.0, 3807.0, 3808.0, 3792.0, 3801.0, 460.0, 360.0, 3471.0, 3994.0, 622785.0], [7.0, 22.0, 9.0, 44.0, 59.0, 3.0, 3801.0, 3803.000000000001, 3794.0, 3801.0, 27.0, 319.0, 3471.0, 3994.0, 661898.0], [7.0, 22.0, 9.0, 59.0, 59.0, 3.0, 3800.0, 3802.000000000001, 3794.0, 3797.0, 24.0, 358.0, 3471.0, 3994.0, 682452.0], [7.0, 22.0, 10.0, 14.0, 59.0, 3.0, 3797.0, 3800.000000000001, 3794.0, 3795.0, 125.0, 320.0, 3471.0, 3994.0, 703059.0], [7.0, 22.0, 10.0, 44.0, 59.0, 3.0, 3796.0, 3797.000000000001, 3782.0, 3787.0, 177.0, 321.0, 3471.0, 3994.0, 775762.0], [7.0, 22.0, 10.0, 59.0, 59.0, 3.0, 3787.0, 3793.0, 3782.0, 3791.0, 58.0, 379.0, 3471.0, 3994.0, 808204.0], [7.0, 22.0, 11.0, 14.0, 59.0, 3.0, 3791.0, 3802.0, 3788.0, 3796.0, 491.0, 197.0, 3471.0, 3994.0, 839485.0], [7.0, 22.0, 11.0, 29.0, 59.0, 3.0, 3796.0, 3802.0, 3793.0, 3798.0, 13.0, 21.0, 3471.0, 3994.0, 859603.0], [7.0, 22.0, 13.0, 44.0, 59.0, 3.0, 3793.0, 3800.0, 3787.0, 3792.0, 394.0, 273.0, 3471.0, 3994.0, 898271.0], [7.0, 22.0, 13.0, 59.0, 59.0, 3.0, 3792.0, 3795.000000000001, 3790.0, 3793.0, 189.0, 242.0, 3471.0, 3994.0, 912888.0], [7.0, 22.0, 14.0, 14.0, 59.0, 3.0, 3794.0, 3797.000000000001, 3792.0, 3797.0, 55.0, 837.0, 3471.0, 3994.0, 927739.0], [7.0, 22.0, 14.0, 29.0, 59.0, 3.0, 3797.0, 3800.0, 3787.0, 3791.0, 810.0, 594.0, 3471.0, 3994.0, 960829.0]], [[7.0, 20.0, 22.0, 59.0, 59.0, 1.0, 3744.0, 3749.0, 3723.0, 3734.0, 37.0, 171.0, 3451.0, 3970.0, 391180.0], [7.0, 20.0, 9.0, 59.0, 59.0, 1.0, 3735.0, 3735.0, 3708.0, 3709.0, 550.0, 595.0, 3451.0, 3970.0, 593648.0], [7.0, 20.0, 10.0, 59.0, 59.0, 1.0, 3708.000000000001, 3718.0000000000005, 3701.0, 3716.0, 205.0, 812.0, 3451.0, 3970.0, 713200.0], [7.0, 20.0, 11.0, 29.0, 59.0, 1.0, 3716.0, 3717.000000000001, 3701.0, 3701.0, 758.0, 177.0, 3451.0, 3970.0, 790258.0], [7.0, 20.0, 13.0, 59.0, 59.0, 1.0, 3701.0000000000005, 3711.0, 3696.000000000001, 3701.0, 7.0, 501.0, 3451.0, 3970.0, 893068.0], [7.0, 20.0, 14.0, 59.0, 59.0, 1.0, 3700.0, 3712.0, 3698.0, 3702.000000000001, 75.0, 40.0, 3451.0, 3970.0, 1026156.0], [7.0, 21.0, 21.0, 59.0, 59.0, 2.0, 3699.0, 3712.0, 3697.0, 3705.0, 102.0, 634.0, 3458.0, 3979.0, 142261.0], [7.0, 21.0, 22.0, 59.0, 59.0, 2.0, 3705.0, 3720.000000000001, 3705.0, 3716.0, 919.0, 198.0, 3458.0, 3979.0, 271396.0], [7.0, 21.0, 9.0, 59.0, 59.0, 2.0, 3716.0, 3745.000000000001, 3716.0, 3738.0, 304.0, 161.0, 3458.0, 3979.0, 579003.0], [7.0, 21.0, 10.0, 59.0, 59.0, 2.0, 3738.0, 3741.000000000001, 3728.0, 3733.0, 257.0, 71.0, 3458.0, 3979.0, 658066.0], [7.0, 21.0, 11.0, 29.0, 59.0, 2.0, 3733.0, 3739.000000000001, 3731.0, 3735.0, 136.0, 329.0, 3458.0, 3979.0, 692400.0], [7.0, 21.0, 13.0, 59.0, 59.0, 2.0, 3733.0, 3743.000000000001, 3731.0, 3741.0, 318.0, 753.0, 3458.0, 3979.0, 740114.0], [7.0, 21.0, 14.0, 59.0, 59.0, 2.0, 3742.0, 3765.000000000001, 3739.0, 3764.000000000001, 323.0, 390.0, 3458.0, 3979.0, 1025122.0], [7.0, 22.0, 21.0, 59.0, 59.0, 3.0, 3768.0, 3795.0, 3768.0, 3790.0, 392.0, 552.0, 3471.0, 3994.0, 344569.0], [7.0, 22.0, 22.0, 59.0, 59.0, 3.0, 3789.0, 3795.000000000001, 3785.0, 3790.0, 11.0, 15.0, 3471.0, 3994.0, 442464.0], [7.0, 22.0, 9.0, 59.0, 59.0, 3.0, 3794.0, 3816.0, 3790.0, 3797.0, 24.0, 358.0, 3471.0, 3994.0, 682452.0], [7.0, 22.0, 10.0, 59.0, 59.0, 3.0, 3797.0, 3800.000000000001, 3782.0, 3791.0, 58.0, 379.0, 3471.0, 3994.0, 808204.0], [7.0, 22.0, 11.0, 29.0, 59.0, 3.0, 3791.0, 3802.0, 3788.0, 3798.0, 13.0, 21.0, 3471.0, 3994.0, 859603.0], [7.0, 22.0, 13.0, 59.0, 59.0, 3.0, 3793.0, 3800.0, 3787.0, 3793.0, 189.0, 242.0, 3471.0, 3994.0, 912888.0], [7.0, 22.0, 14.0, 29.0, 59.0, 3.0, 3794.0, 3800.0, 3787.0, 3791.0, 810.0, 594.0, 3471.0, 3994.0, 960829.0]], [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [7.0, 1.0, 14.0, 59.0, 59.0, 3.0, 3566.0000000000005, 3574.0000000000005, 3542.0, 3567.0, 27.0, 685.0, 3318.0, 3817.0, 1086112.0], [7.0, 2.0, 14.0, 59.0, 59.0, 4.0, 3566.0000000000005, 3581.0000000000005, 3560.0, 3574.0, 46.0, 403.0, 3311.0, 3810.0, 809867.0], [7.0, 3.0, 14.0, 59.0, 59.0, 5.0, 3577.0, 3620.0000000000005, 3574.0, 3619.0, 618.0, 63.0, 3319.0, 3818.0, 1233967.0], [7.0, 6.0, 14.0, 59.0, 59.0, 1.0, 3622.0, 3629.0000000000005, 3603.0, 3618.0, 27.0, 38.0, 3350.0, 3855.0, 949955.0], [7.0, 7.0, 14.0, 59.0, 59.0, 2.0, 3619.0, 3654.0, 3619.0, 3634.0, 32.0, 311.0, 3361.0, 3868.0, 957200.0], [7.0, 8.0, 14.0, 59.0, 59.0, 3.0, 3638.0, 3705.000000000001, 3636.0, 3700.0, 265.0, 7.0, 3381.0, 3890.0, 1368038.0], [7.0, 9.0, 14.0, 59.0, 59.0, 4.0, 3708.0, 3742.000000000001, 3703.0, 3725.0, 40.0, 8.0, 3420.0, 3935.0, 1172490.0], [7.0, 10.0, 14.0, 59.0, 59.0, 5.0, 3723.0, 3723.0, 3680.0, 3690.0, 412.0, 7.0, 3459.0, 3980.0, 1178729.0], [7.0, 13.0, 14.0, 59.0, 59.0, 1.0, 3692.0, 3768.0, 3692.0, 3739.0, 76.0, 364.0, 3443.0, 3962.0, 1377089.0], [7.0, 14.0, 14.0, 59.0, 59.0, 2.0, 3741.0, 3754.000000000001, 3716.0, 3744.0, 67.0, 59.0, 3473.0, 3996.0, 953880.0], [7.0, 15.0, 14.0, 59.0, 59.0, 3.0, 3743.0, 3762.0, 3721.0, 3746.0, 7.0, 5.0, 3473.0, 3996.0, 858731.0], [7.0, 16.0, 14.0, 59.0, 59.0, 4.0, 3739.0, 3743.000000000001, 3692.0, 3702.0, 28.0, 192.0, 3477.0, 4000.0, 1255797.0], [7.0, 17.0, 14.0, 59.0, 59.0, 5.0, 3706.0, 3734.000000000001, 3686.0, 3726.0, 1859.0, 24.0, 3458.0, 3979.0, 926244.0], [7.0, 20.0, 14.0, 59.0, 59.0, 1.0, 3721.0, 3749.0, 3696.000000000001, 3702.000000000001, 75.0, 40.0, 3451.0, 3970.0, 1026156.0], [7.0, 21.0, 14.0, 59.0, 59.0, 2.0, 3699.0, 3765.000000000001, 3697.0, 3764.000000000001, 323.0, 390.0, 3458.0, 3979.0, 1025122.0], [7.0, 22.0, 14.0, 29.0, 59.0, 3.0, 3768.0, 3816.0, 3768.0, 3791.0, 810.0, 594.0, 3471.0, 3994.0, 960829.0]]])\n",
        "r1.size()\n",
        "\n",
        "#%%\n",
        "\n",
        "forward(r1)\n",
        "\n",
        "#%%\n",
        "\n",
        "import torch\n",
        "\n",
        "\n",
        "class TestNN(torch.nn.Module):\n",
        "    def __init__(self):\n",
        "        super(TestNN, self).__init__()\n",
        "        self.cons \u003d torch.nn.Sequential(\n",
        "            torch.nn.Conv2d(time_scales, time_scales, (kernel_h, kernel_w), padding\u003d1),\n",
        "            torch.nn.ReLU(),\n",
        "            torch.nn.Conv2d(time_scales, time_scales, (kernel_h, kernel_w), padding\u003d1),\n",
        "            torch.nn.ReLU(),\n",
        "            torch.nn.Conv2d(time_scales, time_scales, (kernel_h, kernel_w), padding\u003d1),\n",
        "            torch.nn.ReLU())\n",
        "        #cons(fake_tick.unsqueeze(0)).size()\n",
        "        #fake_tick.unsqueeze(0).size()\n",
        "        in_w \u003d len(cons(fake_tick.unsqueeze(0)).view(-1))\n",
        "        self.lins \u003d torch.nn.Sequential(\n",
        "            torch.nn.Linear(in_w, in_w),\n",
        "            torch.nn.Linear(in_w, in_w),\n",
        "            torch.nn.Linear(in_w, action_count))\n",
        "    def forward(self, x):\n",
        "        con_out \u003d self.cons(x.unsqueeze(0))\n",
        "        lin_out \u003d self.lins(con_out.view(-1))\n",
        "        return lin_out, torch.nn.functional.softmax(lin_out)\n",
        "\n",
        "tt \u003d TestNN()\n",
        "torch.onnx.export(tt, r1, \"test.onnx\", opset_version\u003d11)\n",
        ""
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "1635730354424",
      "identity": "99b7ff07-9f9d-4678-acd7-43f73177cc70"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.8"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}